# Start from an empty global environment so objects from a previous
# interactive session cannot leak into this run.
rm(list = ls())

# Every relative path below (config_code/, estimation_code/) is resolved
# against the replication package root.
setwd("conflict_prediction_replication_pkg")

# Load directory paths, packages, helper functions and model parameters, in
# this order. source() evaluates each file in the global environment.
invisible(lapply(
  c(
    "config_code/directories.R",
    "config_code/package_requirements.R",
    "config_code/helper_functions.R",
    "config_code/model_parameters.R"
  ),
  source
))

# Register a 10-worker parallel backend (presumably used by foreach loops in
# the estimation scripts).
# NOTE(review): `cl` is never stopped in this file; workers only go away when
# the R session ends.
cl <- makeCluster(10)
registerDoParallel(cl)

# Positional command-line arguments. commandArgs() (without trailingOnly)
# also contains the R executable and its own flags, so user values sit at a
# fixed offset.
# NOTE(review): assumes an invocation like
#   Rscript <this file> <country> <table> <outcome>
# which puts the user values at positions 6, 7 and 8 — confirm against the
# run scripts.
#   args[6]: country code ("indo" or "colo")
#   args[7]: name of the table/figure to replicate
#   args[8]: outcome variant (e.g. "any", "high", "spike"); some tables below
#            ignore it and fix their own variant
args <- commandArgs()

table <- args[7]

# Every block below is keyed on `table`. Without this check, a missing
# argument would make the script silently run nothing.
if (is.na(table) || !nzchar(table)) {
  stop("No table/figure name supplied: expected it as command-line ",
       "argument 7 of commandArgs().", call. = FALSE)
}

# Appendix Table A2: runs the five learners through the estimation scripts in
# estimation_code/appendix/min_mse/. The country comes from args[6] and the
# outcome variant v from args[8].
if (table %in% "tableA2") {
  for (country in args[6]) { # "indo" or "colo"
    for (v in args[8]) { # "any", "high" or "spike"
      # Country-specific data, outcome and predictor definitions.
      source(file.path("config_code", paste0(country, "_data_setup.R")))
      source(file.path("config_code", paste0(country, "_dependent_vars.R")))
      source(file.path("config_code", paste0(country, "_predictor_vars.R")))

      # The object named by `table` is a named collection of RHS groups.
      for (rhs.group in names(get(table))) {
        rhs <- get(table)[[rhs.group]]
        for (algo in c("lasso", "gbm", "rf", "nn", "ebma")) {
          print(paste0(v, "---", rhs.group, "---", algo))
          # NOTE(review): the sourced script presumably reads v, rhs.group,
          # rhs and algo from the global environment — confirm.
          source(file.path("estimation_code/appendix/min_mse",
                           paste0("conflict_", algo, "_mse.R")))
        }
      }
    }
  }
}


# Appendix Table A3 (parts 1-5) and the Figure A1 variants: runs the five
# learners through estimation_code/appendix/count/. The outcome variant is
# fixed to "count"; args[8] is not used here.
if (table %in% c(paste0("tableA3_pt", 1:5),
                 paste0("figure_A1_", c("fixed", "annual", "slow")))) {
  for (country in args[6]) { # "indo" or "colo"
    for (v in "count") {
      # Country-specific data, outcome and predictor definitions.
      source(file.path("config_code", paste0(country, "_data_setup.R")))
      source(file.path("config_code", paste0(country, "_dependent_vars.R")))
      source(file.path("config_code", paste0(country, "_predictor_vars.R")))

      # The object named by `table` is a named collection of RHS groups.
      for (rhs.group in names(get(table))) {
        rhs <- get(table)[[rhs.group]]
        for (algo in c("lasso", "gbm", "rf", "nn", "ebma")) {
          print(paste0(v, "---", rhs.group, "---", algo))
          source(file.path("estimation_code/appendix/count",
                           paste0("conflict_", algo, "_count.R")))
        }
      }
    }
  }
}

# Appendix Table A4: runs the five learners through
# estimation_code/appendix/demean/ using Table 1's RHS groups. The outcome
# variant is fixed to "demean"; args[8] is not used here.
if (table %in% "tableA4") {
  for (country in args[6]) { # "indo" or "colo"
    for (v in "demean") {
      # Country-specific data, outcome and predictor definitions.
      source(file.path("config_code", paste0(country, "_data_setup.R")))
      source(file.path("config_code", paste0(country, "_dependent_vars.R")))
      source(file.path("config_code", paste0(country, "_predictor_vars.R")))

      # Alias Table 1's RHS groups under this table's name so get(table)
      # resolves them. NOTE(review): table1 is presumably defined by the
      # *_predictor_vars.R file sourced above — confirm.
      tableA4 <- table1
      for (rhs.group in names(get(table))) {
        rhs <- get(table)[[rhs.group]]
        for (algo in c("lasso", "gbm", "rf", "nn", "ebma")) {
          print(paste0(v, "---", rhs.group, "---", algo))
          source(file.path("estimation_code/appendix/demean",
                           paste0("conflict_", algo, "_demean.R")))
        }
      }
    }
  }
}

# Appendix Table A5: runs the five main estimation scripts for Indonesia only.
# The country is hard-coded to "indo", so args[6] is ignored; the outcome
# variant v comes from args[8].
if (table %in% "tableA5") {
  for (country in "indo") {
    for (v in args[8]) {
      # Country-specific data, outcome and predictor definitions.
      source(file.path("config_code", paste0(country, "_data_setup.R")))
      source(file.path("config_code", paste0(country, "_dependent_vars.R")))
      source(file.path("config_code", paste0(country, "_predictor_vars.R")))

      # The object named by `table` is a named collection of RHS groups.
      for (rhs.group in names(get(table))) {
        rhs <- get(table)[[rhs.group]]
        for (algo in c("lasso", "gbm", "rf", "nn", "ebma")) {
          print(paste0(v, "---", rhs.group, "---", algo))
          source(file.path("estimation_code/main",
                           paste0("conflict_", algo, ".R")))
        }
      }
    }
  }
}

# Appendix Table A7 and the Figure A3 variants: runs the five main estimation
# scripts after sourcing config_code/onset_filter.R on top of the usual
# country setup.
# NOTE(review): onset_filter.R presumably restricts the sample to conflict
# onsets — confirm.
if (table %in% c("tableA7", "figure_A3_fixed", "figure_A3_annual",
                 "figure_A3_slow")) {
  for (country in args[6]) { # "indo" or "colo"
    for (v in args[8]) {
      # Country-specific data, outcome and predictor definitions.
      source(paste0("config_code/", country, "_data_setup.R"))
      source(paste0("config_code/", country, "_dependent_vars.R"))
      source(paste0("config_code/", country, "_predictor_vars.R"))
      # The filter path does not depend on country or v, so no paste needed.
      source("config_code/onset_filter.R")

      # The object named by `table` is a named collection of RHS groups.
      for (rhs.group in names(get(table))) {
        rhs <- get(table)[[rhs.group]]
        for (algo in c("lasso", "gbm", "rf", "nn", "ebma")) {
          print(paste(v, rhs.group, algo, sep = "---"))
          source(paste0("estimation_code/main/conflict_", algo, ".R"))
        }
      }
    }
  }
}

# Appendix Table A8 and the Figure A4 variants: runs the five main estimation
# scripts on the main-text RHS groups (table1 / figure_1_*), aliased under the
# appendix names.
# NOTE(review): the RHS groups are identical to the main-text ones; whatever
# distinguishes this appendix run is not visible in this file.
if (table %in% c("tableA8", "figure_A4_annual", "figure_A4_fixed",
                 "figure_A4_slow")) {
  for (country in args[6]) { # "indo" or "colo"
    for (v in args[8]) {
      # Country-specific data, outcome and predictor definitions.
      source(paste0("config_code/", country, "_data_setup.R"))
      source(paste0("config_code/", country, "_dependent_vars.R"))
      source(paste0("config_code/", country, "_predictor_vars.R"))

      # Alias the main-text specifications so get(table) below resolves
      # whichever appendix name was requested.
      tableA8 <- table1
      figure_A4_annual <- figure_1_annual
      figure_A4_fixed <- figure_1_fixed
      figure_A4_slow <- figure_1_slow

      for (rhs.group in names(get(table))) {
        rhs <- get(table)[[rhs.group]]
        for (algo in c("lasso", "gbm", "rf", "nn", "ebma")) {
          print(paste(v, rhs.group, algo, sep = "---"))
          source(paste0("estimation_code/main/conflict_", algo, ".R"))
        }
      }
    }
  }
}

# Appendix Figure A5: runs the five main estimation scripts on the RHS groups
# stored in the object named "figure_A5". Country from args[6], outcome
# variant v from args[8].
if (table %in% "figure_A5") {
  for (country in args[6]) { # "indo" or "colo"
    for (v in args[8]) {
      # Country-specific data, outcome and predictor definitions.
      source(file.path("config_code", paste0(country, "_data_setup.R")))
      source(file.path("config_code", paste0(country, "_dependent_vars.R")))
      source(file.path("config_code", paste0(country, "_predictor_vars.R")))

      for (rhs.group in names(get(table))) {
        rhs <- get(table)[[rhs.group]]
        for (algo in c("lasso", "gbm", "rf", "nn", "ebma")) {
          print(paste0(v, "---", rhs.group, "---", algo))
          source(file.path("estimation_code/main",
                           paste0("conflict_", algo, ".R")))
        }
      }
    }
  }
}

# Appendix Table A9 and the Figure A8 variants: runs the five main estimation
# scripts with the colo_ext_* outcome and predictor definitions, reusing the
# main-text RHS groups under the appendix names.
# NOTE(review): the data setup follows args[6], but the dependent/predictor
# files are always colo_ext_*; presumably only meaningful with
# country = "colo" — confirm.
if (table %in% c("tableA9", "figure_A8_annual", "figure_A8_fixed")) {
  for (country in args[6]) {
    for (v in args[8]) {
      source(file.path("config_code", paste0(country, "_data_setup.R")))
      source(file.path("config_code", "colo_ext_dependent_vars.R"))
      source(file.path("config_code", "colo_ext_predictor_vars.R"))

      # Alias the main-text specifications so get(table) resolves them.
      tableA9 <- table1
      figure_A8_annual <- figure_1_annual
      figure_A8_fixed <- figure_1_fixed

      for (rhs.group in names(get(table))) {
        rhs <- get(table)[[rhs.group]]
        for (algo in c("lasso", "gbm", "rf", "nn", "ebma")) {
          print(paste0(v, "---", rhs.group, "---", algo))
          source(file.path("estimation_code/main",
                           paste0("conflict_", algo, ".R")))
        }
      }
    }
  }
}


# Appendix Table B1: runs the five main estimation scripts on Table 1's RHS
# groups with cv.runs forced to 1 and fileext set to "_1cv".
# NOTE(review): cv.runs / fileext presumably override values from
# model_parameters.R and are read by the estimation scripts (fileext likely
# suffixes output files) — confirm. These assignments persist in the global
# environment after this block.
if (table %in% "tableB1") {
  for (country in args[6]) { # "indo" or "colo"
    for (v in args[8]) {
      # Country-specific data, outcome and predictor definitions.
      source(paste0("config_code/", country, "_data_setup.R"))
      source(paste0("config_code/", country, "_dependent_vars.R"))
      source(paste0("config_code/", country, "_predictor_vars.R"))

      cv.runs <- 1
      fileext <- "_1cv"

      # Table 1's object is referenced directly; get() with a literal name
      # was needless indirection.
      for (rhs.group in names(table1)) {
        rhs <- table1[[rhs.group]]
        for (algo in c("lasso", "gbm", "rf", "nn", "ebma")) {
          print(paste(v, rhs.group, algo, sep = "---"))
          source(paste0("estimation_code/main/conflict_", algo, ".R"))
        }
      }
    }
  }
}

# Appendix Table B2: runs the alternative neural-network scripts (nn2-nn5)
# under estimation_code/appendix/nn_alternatives/ on Table 1's RHS groups.
if (table %in% "tableB2") {
  for (country in args[6]) { # "indo" or "colo"
    for (v in args[8]) {
      # Country-specific data, outcome and predictor definitions.
      source(paste0("config_code/", country, "_data_setup.R"))
      source(paste0("config_code/", country, "_dependent_vars.R"))
      source(paste0("config_code/", country, "_predictor_vars.R"))

      # Table 1's object is referenced directly; get() with a literal name
      # was needless indirection.
      for (rhs.group in names(table1)) {
        rhs <- table1[[rhs.group]]
        for (algo in c("nn2", "nn3", "nn4", "nn5")) {
          print(paste(v, rhs.group, algo, sep = "---"))
          source(paste0("estimation_code/appendix/nn_alternatives/conflict_",
                        algo, ".R"))
        }
      }
    }
  }
}